{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "db6e3b87",
   "metadata": {},
   "source": [
    "# [7、深入剖析PyTorch nn.Module源码](https://www.bilibili.com/video/BV1ch411b7yP?spm_id_from=333.788.player.switch&vd_source=cdd897fffb54b70b076681c3c4e4d45d)\n",
    "\n",
    "其感觉也不深入\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d4197110",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.nn import Module\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor\n",
    "from torch.utils.data import Dataset, DataLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0d4e16c0",
   "metadata": {},
   "source": [
    "## nn.Module\n",
    "\n",
    "基本类，自定义模型都需要继承此类\n",
    "\n",
    "- add_module：增加模块，可以通过 . 来访问\n",
    "- apply：使用一个函数应用到所有子模块，一般由于初始化参数\n",
    "- bfloat16：浮点类型都转换为 bf16 低精度\n",
    "- buffers：返回模块buffer的迭代器，神经网络中参数用来随机梯度下降，buffer不参与。例如BatchNorm\n",
    "- eval：设置为evaluate模式，其实就是 train(False)\n",
    "- load_state_dict：从磁盘保存模型的参数 （torch.save的）\n",
    "- named_parameters：返回所有模型参数\n",
    "- requires_grad_：是否需要梯度更新（在GAN、微调中用到）"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b735fec3",
   "metadata": {},
   "source": [
    "## apply"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "962fd94e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Linear(in_features=2, out_features=2, bias=True)\n",
      "Parameter containing:\n",
      "tensor([[1., 1.],\n",
      "        [1., 1.]], requires_grad=True)\n",
      "Linear(in_features=2, out_features=2, bias=True)\n",
      "Parameter containing:\n",
      "tensor([[1., 1.],\n",
      "        [1., 1.]], requires_grad=True)\n",
      "Sequential(\n",
      "  (0): Linear(in_features=2, out_features=2, bias=True)\n",
      "  (1): Linear(in_features=2, out_features=2, bias=True)\n",
      ")\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "Sequential(\n",
       "  (0): Linear(in_features=2, out_features=2, bias=True)\n",
       "  (1): Linear(in_features=2, out_features=2, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# apply：使用一个函数应用到所有子模块，一般由于初始化参数\n",
    "@torch.no_grad()\n",
    "def init_weights(m):\n",
    "    print(m)\n",
    "    if type(m) == nn.Linear:\n",
    "        m.weight.fill_(1.0)\n",
    "        print(m.weight)\n",
    "\n",
    "\n",
    "net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))\n",
    "net.apply(init_weights)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c05a923",
   "metadata": {},
   "source": [
    "## `register_buffer`\n",
    "\n",
    "有些模型的数据不参与梯度更新，但是是模型的内容，因此需要设置为buffer不参与梯度计算。\n",
    "\n",
    "- `self._buffers[name] = tensor`\n",
    "\n",
    "例如 BatchNorm 中均值是统计信息。"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5c71c61",
   "metadata": {},
   "source": [
    "## `register_parameter`\n",
    "\n",
    "注册参数，会被更新。\n",
    "- `self._parameters[name] = param`\n",
    "\n",
    "param类型的变量，如果用在 Module 中，会被自动加入_parameters表中"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "41e7ac3b",
   "metadata": {},
   "source": [
    "## `get_submodule`\n",
    "\n",
    "```txt\n",
    "A(\n",
    "    (net_b): Module(\n",
    "        (net_c): Module(\n",
    "            (conv): Conv2d(16, 33, kernel_size=(3, 3), stride=(2, 2))\n",
    "        )\n",
    "        (linear): Linear(in_features=100, out_features=200, bias=True)\n",
    "    )\n",
    ")\n",
    "```\n",
    "\n",
    "存在嵌套结构，使用`.`：`get_submodule(\"net_b.net_c.conv\")`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d0ab6169",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "dl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
